import find_mxnet
import mxnet as mx

# Sanity-check script: build the MNIST training iterator and count how many
# batches a single full pass over it yields, then print that count.

data_dir = "mnist/"      # must end with "/" because file names are appended
data_shape = (784, )     # each sample is presented as a flat 784-element vector
batch_size = 128

train = mx.io.MNISTIter(
            image       = data_dir + "train-images-idx3-ubyte",
            label       = data_dir + "train-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size  = batch_size,
            shuffle     = True,
            num_parts   = 1,      # read the whole dataset as a single partition
            part_index  = 0)

# Count the batches produced by one pass over the iterator. The count starts
# at zero; the previous version started at 1 and so reported one batch too
# many.
# NOTE(review): whether a final partial batch is included depends on the
# iterator's padding behaviour — confirm against the MNISTIter docs.
n = sum(1 for _ in train)

# print(...) with a single argument behaves identically on Python 2 and 3.
print(n)